import os
import re

def get_rank_id_from_ckpt_name(ckpt_file):
    """Get rank id from ckpt name.

    The rank id is read from the first ``_rank_<digits>`` marker found in the
    base name of ``ckpt_file`` (any directory part is ignored).

    Args:
        ckpt_file: Checkpoint path or file name, e.g. ``llama_7b_rank_0-3_2.ckpt``.

    Returns:
        int: The rank id parsed from the marker (leading zeros are dropped).

    Raises:
        ValueError: If the file name contains no ``_rank_<digits>`` marker.
    """
    ckpt_name = os.path.split(ckpt_file)[1]
    match = re.search(r'_rank_(\d+)', ckpt_name)
    if match:
        return int(match.group(1))
    # The name must carry a `_rank_<id>` marker; the plain
    # `{prefix}-{epoch}_{step}.ckpt` format alone is not enough.
    raise ValueError(f"Can't match rank id in checkpoint: {ckpt_file}. "
                     "Please ensure the format of the ckpt_file is "
                     "{prefix}_rank_{rank_id}-{epoch}_{step}.ckpt, "
                     "for example, llama_7b_rank_0-3_2.ckpt.")


def replace_rank_id_in_ckpt_name(ckpt_file, dst_rank_id):
    """Replace rank id to dst_rank_id in ckpt name.

    Only the first ``_rank_<digits>`` marker is rewritten. This is the marker
    ``get_rank_id_from_ckpt_name`` reads, and the whole digit run is replaced,
    so ``_rank_1`` inside ``_rank_12`` or a zero-padded ``_rank_01`` is handled
    correctly.

    Args:
        ckpt_file: Checkpoint path or file name; only the base name is used.
        dst_rank_id: Rank id to write into the name.

    Returns:
        str: The base name with its rank id replaced (directory part dropped).

    Raises:
        ValueError: If the file name contains no ``_rank_<digits>`` marker.
    """
    ckpt_name = os.path.split(ckpt_file)[1]
    # A callable replacement avoids re-escape interpretation of dst_rank_id.
    new_name, n_subs = re.subn(r'_rank_\d+', lambda _m: f"_rank_{dst_rank_id}",
                               ckpt_name, count=1)
    if not n_subs:
        # No marker found: raise the shared, descriptive ValueError.
        get_rank_id_from_ckpt_name(ckpt_name)
    return new_name


def check_ckpt_file_name(ckpt_file):
    """Check ckpt name in the format of {prefix}-{epoch}_{step}.ckpt.

    Only the base name of ``ckpt_file`` is checked; any directory part is
    ignored.

    Args:
        ckpt_file: Checkpoint path or file name.

    Returns:
        bool: True if the whole base name matches
        ``{prefix}-{epoch}_{step}.ckpt`` (epoch and step are decimal digits),
        otherwise False.
    """
    ckpt_name = os.path.split(ckpt_file)[1]
    # fullmatch rather than match + `$`: `$` also matches just before a
    # trailing newline, so "a-1_2.ckpt\n" would wrongly be accepted.
    return re.fullmatch(r'[^/]+-\d+_\d+\.ckpt', ckpt_name) is not None
